{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools \n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torch.optim as optim\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sparse CNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea of the sparse CNN is to utilize only the important values of its filters and set the less important ones to zero.\n",
    "\n",
    "For terminology, we'll call the weights of a specific channel a \"filter\", and we'll call each value of a filter a \"connection\".\n",
    "Via Hebbian-inspired learning, those connections that see high absolute values in both their inputs and outputs will be favored.\n",
    "This comparison of input and output will be referred to as the connection's strength."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Connection Strength Calculations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize a typical conv layer to start.\n",
    "From here, three separate methods will be attempted to calculate the strength of the connections."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_out = 7\n",
    "c_in = 8\n",
    "conv = torch.nn.Conv2d(c_in, c_out, kernel_size=(2, 2), stride=(1, 1), padding=0, dilation=1, groups=1)\n",
    "input_tensor = torch.randn(2, c_in, 5, 3)\n",
    "output_tensor = conv(input_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get params of conv layer.\n",
    "in_channels = conv.in_channels  # Number of channels in the input image\n",
    "out_channels = conv.out_channels  # Number of channels produced by the convolution\n",
    "kernel_size = conv.kernel_size  # Size of the convolving kernel\n",
    "stride = conv.stride  # Stride of the convolution. Default: 1\n",
    "padding = conv.padding  # Zero-padding added to both sides of the input. Default: 0\n",
    "padding_mode = conv.padding_mode  # zeros\n",
    "dilation = conv.dilation  # Spacing between kernel elements. Default: 1\n",
    "groups = conv.groups  # Number of blocked connections from input channels to output channels. Default: 1\n",
    "bias = conv.bias is not None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vectorized Method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea of this method is to utilize convolutional arithmetic to determine the input for a given output unit and a given connection.\n",
    "\n",
    "Suppose we initialize a weight matrix of exactly the same dimensions as our original conv layer, and set all of its filters to 0 except for one connection. That is,\n",
    "```\n",
    "new_conv.weight[:, c, j, h] = 1.\n",
    "```\n",
    "Now if we pass the input through `new_conv`, we'll be given an output tensor of the same size as the original, but with the input values arranged at the locations of their respective outputs through the connection. That is,\n",
    "```\n",
    "old_output = conv(input)\n",
    "new_output = new_conv(input)\n",
    "\n",
    "# ==> for all b, n, and m (b being the batch index; n and m the output position), we have\n",
    "# new_output[b, :, n, m] = input[<indices of input passed through connection conv.weight[:, c, j, h]>]\n",
    "\n",
    "examine_connections(old_output, new_output) # done in pointwise fashion\n",
    "```\n",
    "With this vectorized calculation, we may then loop over all combinations of `c`, `j`, and `h`, compare the outputs to their respective inputs, and populate a matrix to record the strengths."
   ]
  },
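  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what the loop over `c`, `j`, and `h` accumulates, assume the coactivation rule used in `f1` (both sigmoid activations above 0.5) and the stride 1, padding 0, dilation 1 of the example layer. The symbols $x$ (input), $y$ (output), $o$ (output channel), and $(n, m)$ (output position) are notation introduced here for illustration:\n",
    "\n",
    "$$\n",
    "H[o, c, j, h] = \\sum_{b} \\sum_{n, m} \\mathbb{1}\\big[\\sigma(y[b, o, n, m]) > 0.5\\big] \\cdot \\mathbb{1}\\big[\\sigma(x[b, c, n + j, m + h]) > 0.5\\big]\n",
    "$$\n",
    "\n",
    "Since $\\sigma(z) > 0.5$ exactly when $z > 0$, each term counts a position where the output unit and the input value feeding it through that connection are both positive."
   ]
  },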
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_single_unit_conv(c, j, h, **kwargs):\n",
    "    \"\"\"\n",
    "    Constructs and returns conv layer with training disabled and\n",
    "    all zero weights except along the output channels for unit\n",
    "    specified as (c, j, h).\n",
    "    \"\"\"\n",
    "    \n",
    "    # Construct conv.\n",
    "    conv = torch.nn.Conv2d(**kwargs)\n",
    "    \n",
    "    # Turn off training: switch to eval mode and stop tracking gradients.\n",
    "    conv.eval()\n",
    "    for param in conv.parameters():\n",
    "        param.requires_grad_(False)\n",
    "    \n",
    "    # Set weights to zero except those specified.\n",
    "    with torch.no_grad():\n",
    "        conv.weight.set_(torch.zeros_like(conv.weight))\n",
    "        conv.weight[:, c, j, h] = 1\n",
    "        \n",
    "    return conv\n",
    "\n",
    "# Get indices that loop over all connections.\n",
    "single_unit_indxs = list(itertools.product(*[range(d) for d in conv.weight.shape[1:]]))\n",
    "\n",
    "single_unit_convs = [\n",
    "    get_single_unit_conv(\n",
    "        c, j, h,\n",
    "        in_channels=in_channels,\n",
    "        out_channels=out_channels,\n",
    "        kernel_size=kernel_size,\n",
    "        stride=stride,\n",
    "        padding=padding,\n",
    "        padding_mode=padding_mode,\n",
    "        dilation=dilation,\n",
    "        groups=groups,\n",
    "        bias=False,\n",
    "    )\n",
    "    for c, j, h in single_unit_indxs\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f1():\n",
    "    \"\"\"\n",
    "    Calculate connection strengths.\n",
    "    \"\"\"\n",
    "    H = torch.zeros_like(conv.weight)\n",
    "    s1 = torch.sigmoid(output_tensor).gt_(0.5)\n",
    "    for idx, sconv in zip(single_unit_indxs, single_unit_convs):\n",
    "\n",
    "        s2 = torch.sigmoid(sconv(input_tensor)).gt_(0.5)\n",
    "        m = torch.sum(s1.mul(s2), (0, 2, 3,))\n",
    "\n",
    "        H[:, idx[0], idx[1], idx[2]] += m\n",
    "        \n",
    "    return H"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vectorized Method + Grouping\n",
    "Same as the previous method, but utilizing the `groups` argument of the conv layer\n",
    "so that only one conv layer is needed."
   ]
  },
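  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the bookkeeping concrete, the sizes for the example layer above (8 input channels, 7 output channels, a 2x2 kernel, `groups=1`) can be worked out as follows. $G$ is a symbol introduced here for the number of connections per filter, which becomes the number of groups of the stacked conv:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "G &= C_{in} \\cdot k_1 \\cdot k_2 = 8 \\cdot 2 \\cdot 2 = 32 \\\\\\\\\n",
    "\\text{stacked input channels} &= C_{in} \\cdot G = 8 \\cdot 32 = 256 \\\\\\\\\n",
    "\\text{stacked output channels} &= C_{out} \\cdot G = 7 \\cdot 32 = 224\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Each group sees one copy of the 8 input channels and produces the 7 output channels belonging to a single connection, so one forward pass covers all 32 single-connection convs."
   ]
  },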
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_single_unit_weights(shape, c, j, h, **kwargs):\n",
    "    \"\"\"\n",
    "    Constructs and returns a weight tensor of the given shape that is\n",
    "    all zeros except along the output channels for the unit\n",
    "    specified as (c, j, h).\n",
    "    \"\"\"\n",
    "    \n",
    "    # Construct weight.\n",
    "    weight = torch.zeros(shape)\n",
    "    \n",
    "    # Set weights to zero except those specified.\n",
    "    weight[:, c, j, h] = 1\n",
    "        \n",
    "    return weight\n",
    "\n",
    "# Compute indices that loop over all connections of a channel.\n",
    "filter_indxs = list(itertools.product(*[range(d) for d in conv.weight.shape[1:]]))\n",
    "\n",
    "# Compute indices that loop over all channels and filters.\n",
    "# This will be used to unpack the pointwise comparisons of the output.\n",
    "connection_indxs = []\n",
    "for idx in filter_indxs:\n",
    "    i_ = list(idx)\n",
    "    connection_indxs.extend([\n",
    "        [c_]+i_ for c_ in range(out_channels)\n",
    "    ])\n",
    "connection_indxs = list(zip(*connection_indxs))\n",
    "\n",
    "# Create new conv layer that groups its input and output.\n",
    "new_groups = len(filter_indxs)\n",
    "stacked_conv = torch.nn.Conv2d(\n",
    "    in_channels=in_channels * new_groups,\n",
    "    out_channels=out_channels * new_groups,\n",
    "    kernel_size=kernel_size,\n",
    "    stride=stride,\n",
    "    padding=padding,\n",
    "    padding_mode=padding_mode,\n",
    "    dilation=dilation,\n",
    "    groups=groups * new_groups,\n",
    "    bias=False,\n",
    ")\n",
    "\n",
    "# Populate the weight matrix with stacked tensors having only one non-zero unit.\n",
    "single_unit_weights = [\n",
    "    get_single_unit_weights(\n",
    "        conv.weight.shape,\n",
    "        c, j, h,\n",
    "    )\n",
    "    for c, j, h in filter_indxs\n",
    "]\n",
    "with torch.no_grad():\n",
    "    stacked_conv.weight.set_(torch.cat(single_unit_weights, dim=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f2():\n",
    "    \"\"\"\n",
    "    Calculate connection strengths with a single grouped conv.\n",
    "    \"\"\"\n",
    "    stacked_input = input_tensor.repeat((1, new_groups, 1, 1))\n",
    "    stacked_output = stacked_conv(stacked_input)\n",
    "\n",
    "    H = torch.zeros_like(conv.weight)\n",
    "\n",
    "    s1 = torch.sigmoid(stacked_output).gt_(0.5)\n",
    "    s2 = torch.sigmoid(output_tensor).gt_(0.5).repeat((1, new_groups, 1, 1))\n",
    "\n",
    "    H_ = torch.sum(s2.mul(s1), (0, 2, 3,))\n",
    "\n",
    "    H[connection_indxs] = H_\n",
    "\n",
    "    return H"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vectorized Method with Fewer Redundancies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_single_unit_weights_2b(shape, c, j, h, **kwargs):\n",
    "    \"\"\"\n",
    "    Constructs and returns a weight tensor with a single output\n",
    "    channel that is all zeros except for the unit\n",
    "    specified as (c, j, h).\n",
    "    \"\"\"\n",
    "    \n",
    "    # Construct weight.\n",
    "    weight = torch.zeros(1, *shape[1:])\n",
    "    \n",
    "    # Set weights to zero except those specified.\n",
    "    weight[0, c, j, h] = 1\n",
    "        \n",
    "    return weight\n",
    "\n",
    "# Compute indices that loop over all connections of a channel.\n",
    "filter_indxs_2b = list(itertools.product(*[range(d) for d in conv.weight.shape[1:]]))\n",
    "\n",
    "# Compute indices that loop over all channels and filters.\n",
    "# This will be used to unpack the pointwise comparisons of the output.\n",
    "connection_indxs_2b = []\n",
    "for c_ in range(out_channels):\n",
    "    for idx in filter_indxs_2b:\n",
    "        i_ = list(idx)\n",
    "        connection_indxs_2b.append([c_] + i_)\n",
    "connection_indxs_2b = list(zip(*connection_indxs_2b))\n",
    "\n",
    "new_groups_2b = int(np.prod(conv.weight.shape[1:]))\n",
    "perm_indices_2b = []\n",
    "for c_i in range(out_channels):\n",
    "    perm_indices_2b.extend(\n",
    "        [c_i] * new_groups_2b\n",
    "    )\n",
    "\n",
    "# Create new conv layer that groups its input and output.\n",
    "stacked_conv_2b = torch.nn.Conv2d(\n",
    "    in_channels=in_channels * new_groups_2b,\n",
    "    out_channels=new_groups_2b,\n",
    "    kernel_size=kernel_size,\n",
    "    stride=stride,\n",
    "    padding=padding,\n",
    "    padding_mode=padding_mode,\n",
    "    dilation=dilation,\n",
    "    groups=groups * new_groups_2b,\n",
    "    bias=False,\n",
    ")\n",
    "# Populate the weight matrix with stacked tensors having only one non-zero unit.\n",
    "single_unit_weights_2b = [\n",
    "    get_single_unit_weights_2b(\n",
    "        conv.weight.shape,\n",
    "        c, j, h,\n",
    "    )\n",
    "    for c, j, h in filter_indxs_2b\n",
    "]\n",
    "with torch.no_grad():\n",
    "    stacked_conv_2b.weight.set_(torch.cat(single_unit_weights_2b, dim=0))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f2b():\n",
    "    \n",
    "    stacked_input = input_tensor.repeat((1, new_groups_2b, 1, 1))\n",
    "    stacked_output = stacked_conv_2b(stacked_input).repeat((1, out_channels, 1, 1))\n",
    "    \n",
    "\n",
    "    H = torch.zeros_like(conv.weight)\n",
    "\n",
    "    s1 = stacked_output\n",
    "    s2 = output_tensor[:, perm_indices_2b, ...]\n",
    "\n",
    "    mu_in = s1.mean(dim=0)\n",
    "    mu_out = s2.mean(dim=0)\n",
    "\n",
    "    std_in = s1.std(dim=0)\n",
    "    std_out = s2.std(dim=0)\n",
    "    \n",
    "    corr = ((s1 - mu_in) * (s2 - mu_out)).mean(dim=0) / (std_in * std_out)\n",
    "    \n",
    "    corr[torch.where((std_in == 0 ) | (std_out == 0 ))] = 0\n",
    "    corr = corr.abs()\n",
    "    \n",
    "    H_ = torch.mean(corr, (1, 2))\n",
    "\n",
    "    H[connection_indxs_2b] = H_\n",
    "    \n",
    "    return H"
   ]
  },
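  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`f2b` swaps the coactivation count for a correlation-style score. As a sketch of what it computes, write $x_b = x[b, c, n + j, m + h]$ and $y_b = y[b, o, n, m]$ (notation introduced here, assuming the stride 1, padding 0 example layer), with bars denoting means over the batch of size $B$ and $N \\times M$ the spatial size of the output:\n",
    "\n",
    "$$\n",
    "r_{o,c,j,h}(n, m) = \\frac{\\frac{1}{B} \\sum_{b} (x_b - \\bar{x})(y_b - \\bar{y})}{s_x \\, s_y}, \\qquad H[o, c, j, h] = \\frac{1}{N M} \\sum_{n, m} \\left| r_{o,c,j,h}(n, m) \\right|\n",
    "$$\n",
    "\n",
    "Here $s_x$ and $s_y$ are the batch standard deviations returned by `Tensor.std`, which applies Bessel's correction by default, so $r$ is the Pearson correlation scaled by $(B - 1) / B$. It is set to 0 wherever either standard deviation is 0."
   ]
  },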
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Brute Force Method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Computationally speaking, this is the same method as the first two. Only now, instead of using conv layers to assist in the computations, we use `for` loops to brute force our way through.\n",
    "\n",
    "This is more so a sanity check on the first two to validate their outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def coactivation(t1, t2):\n",
    "    s = (torch.sigmoid(t1) > 0.5) * (torch.sigmoid(t2) > 0.5)\n",
    "    return s \n",
    "    \n",
    "def get_indeces_of_input_and_filter(n, m):\n",
    "    \"\"\"\n",
    "    Assumes dilation=1, i.e. typical conv.\n",
    "    \"\"\"\n",
    "    \n",
    "    k1, k2 = kernel_size\n",
    "    p1, p2 = padding\n",
    "    s1, s2 = stride\n",
    "    \n",
    "    i1, i2 = (0, 0)\n",
    "    \n",
    "    i1 -= p1\n",
    "    i2 -= p2\n",
    "    \n",
    "    i1 += n * s1\n",
    "    i2 += m * s2\n",
    "    \n",
    "    indxs = []\n",
    "    for c_in in range(in_channels):\n",
    "        for n_k1 in range(k1):\n",
    "            for m_k2 in range(k2):\n",
    "                filter_indx = (c_in,      n_k1,      m_k2)\n",
    "                input_indx  = (c_in, i1 + n_k1, i2 + m_k2)\n",
    "                indxs.append((input_indx, filter_indx))\n",
    "                \n",
    "    return indxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "B     = output_tensor.shape[0]\n",
    "N_out = output_tensor.shape[2]\n",
    "M_out = output_tensor.shape[3]\n",
    "C_in  = conv.weight.shape[1]\n",
    "C_out = conv.weight.shape[0]\n",
    "\n",
    "def f3():\n",
    "    H = torch.zeros_like(conv.weight)\n",
    "    for b in range(B):\n",
    "        for c_out in range(C_out):\n",
    "            for n_out in range(N_out):\n",
    "                for m_out in range(M_out):\n",
    "                    unit_1 = output_tensor[b, c_out, n_out, m_out]\n",
    "                    indxs  = get_indeces_of_input_and_filter(n_out, m_out)\n",
    "\n",
    "                    for input_indx, filter_indx in indxs:\n",
    "                        c_in, n_in, m_in = input_indx\n",
    "                        c_fl, n_fl, m_fl = filter_indx\n",
    "                        unit_2 = input_tensor[b, c_in, n_in, m_in]\n",
    "\n",
    "                        if coactivation(unit_1, unit_2):\n",
    "                            H[c_out, c_fl, n_fl, m_fl] += 1\n",
    "                            \n",
    "    return H\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validation / Time Trials"
   ]
  },
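  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quick test to make sure the coactivation-based methods (`f1`, `f2`, and `f3`) all give the same output; `f2b` scores connections by correlation instead, so it is only timed.\n",
    "Let's see how long they take to run."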
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.64 ms ± 162 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
      "326 µs ± 6.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
      "348 µs ± 5.33 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
     ]
    }
   ],
   "source": [
    "# f2b measures correlation rather than coactivation counts, so it is timed but not compared.\n",
    "assert f1().allclose(f2(), rtol=0, atol=0) and f2().allclose(f3(), rtol=0, atol=0)\n",
    "%timeit f1()\n",
    "%timeit f2()\n",
    "%timeit f2b()\n",
    "# %timeit f3()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sparse Conv Layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now to implement a conv layer that utilizes the second implementation above (the vectorized method with grouping)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DSConv2d(torch.nn.Conv2d):\n",
    "    \n",
    "    def __init__(self, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "\n",
    "        self.learning_iterations = 0\n",
    "        \n",
    "        self.activity_threshold = 0.5\n",
    "        self.k1                 = max(int(0.1 * np.prod(self.weight.shape[2:])), 1)\n",
    "        self.k2                 = max(int(0.15 * np.prod(self.weight.shape[2:])), 1)\n",
    "        self.prune_dims         = [0, 1] # Todo: sort\n",
    "        \n",
    "        self.connections_tensor = torch.zeros_like(self.weight)\n",
    "        self.prune_mask = torch.ones_like(self.weight)\n",
    "        \n",
    "        # Compute indices that loop over all connections of a channel.\n",
    "        filter_indxs = list(itertools.product(*[range(d) for d in self.weight.shape[1:]]))\n",
    "\n",
    "        # Compute indices that loop over all channels and filters.\n",
    "        # This will be used to unpack the pointwise comparisons of the output.\n",
    "        self.connection_indxs = []\n",
    "        for idx in filter_indxs:\n",
    "            i_ = list(idx)\n",
    "            self.connection_indxs.extend([\n",
    "                [c]+i_ for c in range(self.weight.shape[0])\n",
    "            ])\n",
    "        self.connection_indxs = list(zip(*self.connection_indxs))\n",
    "\n",
    "        # Create new conv layer that groups its input and output.\n",
    "        self.new_groups = len(filter_indxs)\n",
    "        self.stacked_conv = torch.nn.Conv2d(\n",
    "            in_channels=self.in_channels * self.new_groups,\n",
    "            out_channels=self.out_channels * self.new_groups,\n",
    "            kernel_size=self.kernel_size,\n",
    "            stride=self.stride,\n",
    "            padding=self.padding,\n",
    "            padding_mode=self.padding_mode,\n",
    "            dilation=self.dilation,\n",
    "            groups=self.groups * self.new_groups,\n",
    "            bias=False,\n",
    "        )\n",
    "\n",
    "        # Populate the weight matrix with stacked tensors having only one non-zero unit.\n",
    "        single_unit_weights = [\n",
    "            self.get_single_unit_weights(\n",
    "                self.weight.shape,\n",
    "                c, j, h,\n",
    "            )\n",
    "            for c, j, h in filter_indxs\n",
    "        ]\n",
    "        with torch.no_grad():\n",
    "            self.stacked_conv.weight.set_(torch.cat(single_unit_weights, dim=0))\n",
    "    \n",
    "    def get_single_unit_weights(self, shape, c, j, h):\n",
    "        \"\"\"\n",
    "        Constructs and returns a weight tensor of the given shape that is\n",
    "        all zeros except along the output channels for the unit\n",
    "        specified as (c, j, h).\n",
    "        \"\"\"\n",
    "\n",
    "        # Construct weight.\n",
    "        weight = torch.zeros(shape)\n",
    "\n",
    "        # Set weights to zero except those specified.\n",
    "        weight[:, c, j, h] = 1\n",
    "\n",
    "        return weight\n",
    "    \n",
    "    def update_connections_tensor(self, input_tensor, output_tensor):\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            stacked_input = input_tensor.repeat((1, self.new_groups, 1, 1))\n",
    "            stacked_output = self.stacked_conv(stacked_input)\n",
    "\n",
    "            s1 = torch.sigmoid(stacked_output).gt_(0.5)\n",
    "            s2 = torch.sigmoid(output_tensor).gt_(0.5).repeat((1, self.new_groups, 1, 1))\n",
    "            H_ = torch.sum(s2.mul(s1), (0, 2, 3,))\n",
    "\n",
    "            self.connections_tensor[self.connection_indxs] = H_\n",
    "    \n",
    "    def progress_connections(self):\n",
    "        \"\"\"\n",
    "        Prunes and adds connections.\n",
    "        \"\"\"\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            \n",
    "            # Get strengths of all connections.\n",
    "            strengths = self.connections_tensor.numpy()\n",
    "            shape = strengths.shape\n",
    "\n",
    "            # Determine all combinations of prune dimensions\n",
    "            all_dims = range(len(shape))\n",
    "            prune_indxs = [range(shape[d]) if d in self.prune_dims else [slice(None)] for d in all_dims]\n",
    "            prune_indxs = itertools.product(*prune_indxs)\n",
    "\n",
    "            # Along all combinations of prune dimensions:\n",
    "            #    - Keep strongest k1 connections\n",
    "            #    - Reinitialize trailing k2 - k1 connections.\n",
    "            k1 = self.k1\n",
    "            k2 = self.k2\n",
    "            for idx in prune_indxs:\n",
    "\n",
    "                # Get the k1'th largest strength.\n",
    "                s = strengths[idx].flatten()\n",
    "                v1 = np.partition(s, -k1)[-k1] # s.kthvalue(len(s) - k1).value\n",
    "\n",
    "                # Keep the k1 strongest connections - prune those below.\n",
    "                # Boolean torch masks are used so indexing acts as a mask.\n",
    "                c = self.weight[idx].flatten()\n",
    "                prune_mask = torch.from_numpy(s < v1)\n",
    "                c[prune_mask] = 0\n",
    "\n",
    "                # Get trailing k2 - k1 connections.\n",
    "                v2 = np.partition(s, -k2)[-k2] # s.kthvalue(len(s) - k2).value\n",
    "                new_mask = torch.from_numpy(s >= v2) & prune_mask\n",
    "\n",
    "                # Reinitialize trailing k2 - k1 connections.\n",
    "                # Note: [None, :] is added here as kaiming_uniform_ requires a 2d tensor.\n",
    "                # Masked indexing returns a copy, so the result is assigned back.\n",
    "                if new_mask.any():\n",
    "                    new_weights = c[new_mask][None, :]\n",
    "                    torch.nn.init.kaiming_uniform_(new_weights)\n",
    "                    c[new_mask] = new_weights[0]\n",
    "\n",
    "                # Reshape connections and update the weight.\n",
    "                self.weight[idx] = c.reshape(self.weight[idx].shape)\n",
    "                \n",
    "                self.prune_mask = prune_mask\n",
    "\n",
    "            # Reset connection strengths.\n",
    "            self.connections_tensor = torch.zeros_like(self.weight)\n",
    "            \n",
    "    def prune_randomly(self):\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            \n",
    "            prune_mask = torch.rand(self.weight.shape) < 0.85 # prune 85% of weights, keeping 15%\n",
    "            self.weight[prune_mask] = 0\n",
    "            \n",
    "            # Reinitialize those that are zero.\n",
    "            keep_mask = ~prune_mask\n",
    "            new_mask  = (self.weight == 0) & keep_mask\n",
    "            new_weights = self.weight[new_mask]\n",
    "            if len(new_weights) > 0:\n",
    "                torch.nn.init.kaiming_uniform_(new_weights[None, :])\n",
    "                self.weight[new_mask] = new_weights\n",
    "       \n",
    "    def __call__(self, input_tensor, *args, **kwargs):\n",
    "        output_tensor = super().__call__(input_tensor, *args, **kwargs)\n",
    "        if self.learning_iterations % 20 == 0:\n",
    "            self.update_connections_tensor(input_tensor, output_tensor)\n",
    "        self.learning_iterations += 1\n",
    "        return output_tensor"
   ]
  },
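  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of the thresholds set in `__init__`, assume a 5x5 kernel like the ones used by the test network further down. $K$ is a symbol introduced here for the number of values in one kernel slice, `np.prod(self.weight.shape[2:])`, and the floor reflects the truncation done by `int()`:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "K &= 5 \\cdot 5 = 25 \\\\\\\\\n",
    "k_1 &= \\max(\\lfloor 0.1 \\cdot 25 \\rfloor, 1) = \\max(2, 1) = 2 \\\\\\\\\n",
    "k_2 &= \\max(\\lfloor 0.15 \\cdot 25 \\rfloor, 1) = \\max(3, 1) = 3\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Going by the comments in `progress_connections`, each (output channel, input channel) slice would then keep its 2 strongest connections, reinitialize the next $k_2 - k_1 = 1$, and zero the rest."
   ]
  },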
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test Training a Network\n",
    "\n",
    "The following is a simple toy example copied mostly from:\n",
    "https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html\n",
    "\n",
    "The main difference here is that the network utilizes the `DSConv2d` module.\n",
    "This exercise mostly serves to gain confidence in the implementation with respect to its ability to run without errors; it is not concerned with verifying training improvements just yet."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "root_path = '~/nta/datasets'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(root=root_path, train=True,\n",
    "                                        download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,\n",
    "                                          shuffle=True, num_workers=2)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root=root_path, train=False,\n",
    "                                       download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=4,\n",
    "                                         shuffle=False, num_workers=0)\n",
    "\n",
    "classes = ('plane', 'car', 'bird', 'cat',\n",
    "           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Time Trials\n",
    "Quick test to compare runtime with and without updating the connections tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dense CNN forward pass:\n",
      "215 µs ± 3.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
      "DSConv2d forward pass:\n",
      "3.59 ms ± 37.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "cd = torch.nn.Conv2d(3, 6, 5)\n",
    "cs = DSConv2d(3, 6, 5)\n",
    "\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = next(dataiter)\n",
    "\n",
    "print('Dense CNN forward pass:')\n",
    "%timeit cd(images)\n",
    "print('DSConv2d forward pass:')\n",
    "%timeit cs(images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Network Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 278,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        # self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.conv1 = DSConv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        # self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.conv2 = DSConv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "    \n",
    "net = Net()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1,  2000] loss: 2.221\n",
      "Finished Training\n"
     ]
    }
   ],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "for epoch in range(2):  # loop over the dataset multiple times\n",
    "\n",
    "    running_loss = 0.0\n",
    "    for i, data in enumerate(trainloader, 0):\n",
    "\n",
    "        # get the inputs; data is a list of [inputs, labels]\n",
    "        inputs, labels = data\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # print statistics\n",
    "        running_loss += loss.item()\n",
    "        if i % 2000 == 1999:    # print every 2000 mini-batches\n",
    "            print('[%d, %5d] loss: %.3f' %\n",
    "                  (epoch + 1, i + 1, running_loss / 2000))\n",
    "            running_loss = 0.0\n",
    "            \n",
    "            net.conv1.progress_connections()\n",
    "            net.conv2.progress_connections()\n",
    "    \n",
    "            break\n",
    "        \n",
    "    break\n",
    "    \n",
    "#     # Compare with pruning random weights.\n",
    "#     net.conv1.prune_randomly()\n",
    "#     net.conv2.prune_randomly() \n",
    "\n",
    "print('Finished Training')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of plane :  0 %\n",
      "Accuracy of   car :  0 %\n",
      "Accuracy of  bird : 79 %\n",
      "Accuracy of   cat :  2 %\n",
      "Accuracy of  deer : 10 %\n",
      "Accuracy of   dog :  0 %\n",
      "Accuracy of  frog :  1 %\n",
      "Accuracy of horse :  0 %\n",
      "Accuracy of  ship :  0 %\n",
      "Accuracy of truck :  0 %\n"
     ]
    }
   ],
   "source": [
    "class_correct = list(0. for i in range(10))\n",
    "class_total = list(0. for i in range(10))\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        outputs = net(images)\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        c = (predicted == labels).squeeze()\n",
    "        for i in range(4):\n",
    "            label = labels[i]\n",
    "            class_correct[label] += c[i].item()\n",
    "            class_total[label] += 1\n",
    "\n",
    "\n",
    "for i in range(10):\n",
    "    print('Accuracy of %5s : %2d %%' % (\n",
    "        classes[i], 100 * class_correct[i] / class_total[i]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
